{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07c14a1-26c6-4c73-96ae-e754a55f7b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dendron import *\n",
    "from dendron.actions.image_lm_action import ImageLMActionConfig, ImageLMAction\n",
    "\n",
    "import io\n",
    "from PIL import Image\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eb0a024-a15c-4cb0-be45-b4bf899f54d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the image-language-model action: checkpoint, device/loading\n",
    "# options, and sampling parameters used for generation.\n",
    "cfg = ImageLMActionConfig(\n",
    "    model_name=\"llava-hf/llava-1.5-13b-hf\",\n",
    "    device=\"cuda\",\n",
    "    load_in_4bit=True,\n",
    "    use_flash_attn_2=True,\n",
    "    max_new_tokens=70,\n",
    "    do_sample=True,\n",
    "    top_p=0.95,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b875f3f4-6616-4c34-bda3-8e61c1ac03d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the action node from the config and make it the root of a behavior tree.\n",
    "node = ImageLMAction(\"ilm_action\", cfg)\n",
    "tree = BehaviorTree(\"image-ml-tree\", node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e7ea4ec-b3ed-45fe-b40d-3191e607afc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the sample picture; the bare name on the last line displays it in the notebook.\n",
    "image_path = \"./puppy.jpg\"\n",
    "img = Image.open(image_path)\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db421711-b151-4d28-860f-bfc14d979314",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chat-style prompt; the `{query}` placeholder is filled in with str.format.\n",
    "prompt_template = \"USER: <image>\\n{query}\\nASSISTANT:\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6802730-8c06-440a-a4f7-efd96cc5c561",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def trunc(node):\n",
    "    \"\"\"Post-tick hook: keep only the text after the first ``ASSISTANT:`` marker.\n",
    "\n",
    "    Reads ``node.blackboard[node.output_key]`` and overwrites it with the\n",
    "    whitespace-stripped text that follows the marker. If the marker is not\n",
    "    present, the stored value is left untouched instead of raising\n",
    "    ``AttributeError`` on the failed match.\n",
    "    \"\"\"\n",
    "    out_str = node.blackboard[node.output_key]\n",
    "    # DOTALL lets the captured answer span multiple lines.\n",
    "    m = re.search(r'ASSISTANT:(.*)', out_str, re.DOTALL)\n",
    "    if m is None:\n",
    "        # Nothing to trim, e.g. the text was already cleaned by an earlier call.\n",
    "        return\n",
    "    node.blackboard[node.output_key] = m.group(1).strip()\n",
    "\n",
    "# Register `trunc` as a post-tick hook on the tree's root node.\n",
    "# NOTE(review): re-running this cell presumably registers the hook again -- confirm.\n",
    "tree.root.add_post_tick(trunc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d2a4d2d-586d-426b-b7fe-9e5ae5f1a44e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage the inputs on the blackboard: the formatted prompt and the loaded image.\n",
    "query = \"What kind of animal is this? Give a comprehensive answer.\"\n",
    "tree.blackboard[\"text_in\"] = prompt_template.format(query=query)\n",
    "tree.blackboard[\"image_in\"] = img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c4078c-838a-4d52-af4d-8c9c3c72f76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tick the tree once and keep the returned status for the check in the next cell.\n",
    "status = tree.tick_once()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55d68937-b50c-4551-b82b-136cfb75adbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the stored answer on success; otherwise report the status so a tick\n",
    "# that did not succeed is not silently ignored.\n",
    "if status == NodeStatus.SUCCESS:\n",
    "    print(tree.blackboard[\"out\"])\n",
    "else:\n",
    "    print(f\"Tick ended with status {status}; no output to display.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f39210-f585-4225-95f1-b1abafe26227",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretty-print the tree for inspection.\n",
    "tree.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "551f2075-25dc-4262-b760-a3182d9ca79b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f908e2e3-13de-4caa-8c9c-931fc0707023",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
